{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use a Closed-Form Policy to Play PongNoFrameskip-v4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import logging\n",
    "import importlib\n",
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "import gym\n",
    "\n",
    "# logging.basicConfig() does nothing when the root logger already has\n",
    "# handlers, so reload the module first to start from a clean state.\n",
    "# importlib.reload replaces imp.reload: the imp module is deprecated and\n",
    "# no longer exists in Python 3.12+.\n",
    "importlib.reload(logging)\n",
    "logging.basicConfig(level=logging.DEBUG,\n",
    "        format='%(asctime)s [%(levelname)s] %(message)s',\n",
    "        stream=sys.stdout, datefmt='%H:%M:%S')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16:58:43 [INFO] env: <AtariEnv<PongNoFrameskip-v4>>\n",
      "16:58:43 [INFO] action_space: Discrete(6)\n",
      "16:58:43 [INFO] observation_space: Box(0, 255, (210, 160, 3), uint8)\n",
      "16:58:43 [INFO] reward_range: (-inf, inf)\n",
      "16:58:43 [INFO] metadata: {'render.modes': ['human', 'rgb_array']}\n",
      "16:58:43 [INFO] _max_episode_steps: 400000\n",
      "16:58:43 [INFO] _elapsed_steps: None\n"
     ]
    }
   ],
   "source": [
    "env = gym.make('PongNoFrameskip-v4')\n",
    "env.seed(0)\n",
    "# Log every instance attribute of the environment object.\n",
    "for key, value in vars(env).items():\n",
    "    logging.info('%s: %s', key, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CloseFormAgent:\n",
    "    \"\"\"Hand-written Pong policy that keeps the racket level with the ball.\n",
    "\n",
    "    The agent keeps no state: reset() and close() do nothing, and every\n",
    "    action is computed from the current observation alone.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, _):\n",
    "        \"\"\"Create the agent. The single argument is ignored.\"\"\"\n",
    "\n",
    "    def reset(self, mode=None):\n",
    "        \"\"\"Do nothing; `mode` is accepted and ignored.\"\"\"\n",
    "\n",
    "    def obs2loc(self, observation):\n",
    "        \"\"\"Estimate the racket and ball positions in one frame.\n",
    "\n",
    "        Only channel 0 of rows 34..192 of `observation` is inspected. A pixel\n",
    "        is attributed to the racket if its value is 92 and to the ball if it\n",
    "        is 236. When an object touches the first or last row of that window\n",
    "        it may be partly cut off, so its row range is rebuilt from the visible\n",
    "        edge and the nominal height (16 rows for the racket, 4 for the ball).\n",
    "\n",
    "        Args:\n",
    "            observation: array indexed as observation[row, column, channel]\n",
    "                (assumed to be the 210x160x3 frame returned by the env).\n",
    "\n",
    "        Returns:\n",
    "            dict with keys 'racketx', 'rackety', 'ballx', 'bally'. The x values\n",
    "            are mean column indices, the y values are mean row indices\n",
    "            relative to the cropped window. A value is np.nan when no pixel\n",
    "            of that object is found.\n",
    "        \"\"\"\n",
    "        colors = {'racket': 92, 'ball': 236}\n",
    "        heights = {'racket': 16, 'ball': 4}\n",
    "        ymin, ymax = 34, 193\n",
    "        # The cropped window has ymax - ymin rows, so its last valid row index\n",
    "        # is one less than that.\n",
    "        last_row = ymax - ymin - 1\n",
    "        field = observation[ymin:ymax, :, 0]\n",
    "        locations = {}\n",
    "        for obj, color in colors.items():\n",
    "            match = field == color\n",
    "            xx = np.where(match.any(axis=0))[0]\n",
    "            yy = np.where(match.any(axis=1))[0]\n",
    "            if yy.size and yy.min() == 0:\n",
    "                # Touches the top edge: extend upwards from the lowest visible row.\n",
    "                yy = np.arange(yy.max()-heights[obj]+1, yy.max()+1, 1)\n",
    "            elif yy.size and yy.max() == last_row:\n",
    "                # Touches the bottom edge: extend downwards from the highest visible row.\n",
    "                yy = np.arange(yy.min(), yy.min()+heights[obj], 1)\n",
    "            locations[obj + 'x'] = xx.mean() if xx.size else np.nan\n",
    "            locations[obj + 'y'] = yy.mean() if yy.size else np.nan\n",
    "        return locations\n",
    "\n",
    "    def step(self, observation, _reward, _done):\n",
    "        \"\"\"Pick the action that moves the racket towards the ball's row.\n",
    "\n",
    "        Args:\n",
    "            observation: current frame, passed to obs2loc().\n",
    "            _reward: ignored.\n",
    "            _done: ignored.\n",
    "\n",
    "        Returns:\n",
    "            2 if the ball is above the racket, 3 if it is below, otherwise 0.\n",
    "            0 is also returned when either object was not located, because\n",
    "            comparisons against np.nan are always False.\n",
    "        \"\"\"\n",
    "        locations = self.obs2loc(observation)\n",
    "        if locations['bally'] < locations['rackety']:\n",
    "            action = 2 # move up\n",
    "        elif locations['bally'] > locations['rackety']:\n",
    "            action = 3 # move down\n",
    "        else:\n",
    "            action = 0\n",
    "        return action\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"Do nothing; the agent holds no resources.\"\"\"\n",
    "\n",
    "\n",
    "# CloseFormAgent.__init__ ignores its argument, so env is not used by the agent.\n",
    "agent = CloseFormAgent(env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16:58:43 [INFO] ==== test ====\n",
      "16:59:18 [DEBUG] test episode 0: reward = 21.00, steps = 27445\n",
      "16:59:52 [DEBUG] test episode 1: reward = 21.00, steps = 27445\n",
      "17:00:27 [DEBUG] test episode 2: reward = 21.00, steps = 27445\n",
      "17:01:01 [DEBUG] test episode 3: reward = 21.00, steps = 27445\n",
      "17:01:36 [DEBUG] test episode 4: reward = 21.00, steps = 27445\n",
      "17:02:10 [DEBUG] test episode 5: reward = 21.00, steps = 27445\n",
      "17:02:44 [DEBUG] test episode 6: reward = 21.00, steps = 27445\n",
      "17:03:18 [DEBUG] test episode 7: reward = 21.00, steps = 27445\n",
      "17:03:52 [DEBUG] test episode 8: reward = 21.00, steps = 27445\n",
      "17:04:27 [DEBUG] test episode 9: reward = 21.00, steps = 27445\n",
      "17:05:02 [DEBUG] test episode 10: reward = 21.00, steps = 27445\n",
      "17:05:35 [DEBUG] test episode 11: reward = 21.00, steps = 27445\n",
      "17:06:11 [DEBUG] test episode 12: reward = 21.00, steps = 27445\n",
      "17:06:45 [DEBUG] test episode 13: reward = 21.00, steps = 27445\n",
      "17:07:20 [DEBUG] test episode 14: reward = 21.00, steps = 27445\n",
      "17:07:54 [DEBUG] test episode 15: reward = 21.00, steps = 27445\n",
      "17:08:29 [DEBUG] test episode 16: reward = 21.00, steps = 27445\n",
      "17:09:03 [DEBUG] test episode 17: reward = 21.00, steps = 27445\n",
      "17:09:36 [DEBUG] test episode 18: reward = 21.00, steps = 27445\n",
      "17:10:08 [DEBUG] test episode 19: reward = 21.00, steps = 27445\n",
      "17:10:41 [DEBUG] test episode 20: reward = 21.00, steps = 27445\n",
      "17:11:13 [DEBUG] test episode 21: reward = 21.00, steps = 27445\n",
      "17:11:46 [DEBUG] test episode 22: reward = 21.00, steps = 27445\n",
      "17:12:18 [DEBUG] test episode 23: reward = 21.00, steps = 27445\n",
      "17:12:50 [DEBUG] test episode 24: reward = 21.00, steps = 27445\n",
      "17:13:23 [DEBUG] test episode 25: reward = 21.00, steps = 27445\n",
      "17:13:55 [DEBUG] test episode 26: reward = 21.00, steps = 27445\n",
      "17:14:28 [DEBUG] test episode 27: reward = 21.00, steps = 27445\n",
      "17:15:01 [DEBUG] test episode 28: reward = 21.00, steps = 27445\n",
      "17:15:33 [DEBUG] test episode 29: reward = 21.00, steps = 27445\n",
      "17:16:05 [DEBUG] test episode 30: reward = 21.00, steps = 27445\n",
      "17:16:38 [DEBUG] test episode 31: reward = 21.00, steps = 27445\n",
      "17:17:10 [DEBUG] test episode 32: reward = 21.00, steps = 27445\n",
      "17:17:42 [DEBUG] test episode 33: reward = 21.00, steps = 27445\n",
      "17:18:14 [DEBUG] test episode 34: reward = 21.00, steps = 27445\n",
      "17:18:46 [DEBUG] test episode 35: reward = 21.00, steps = 27445\n",
      "17:19:19 [DEBUG] test episode 36: reward = 21.00, steps = 27445\n",
      "17:19:51 [DEBUG] test episode 37: reward = 21.00, steps = 27445\n",
      "17:20:24 [DEBUG] test episode 38: reward = 21.00, steps = 27445\n",
      "17:20:58 [DEBUG] test episode 39: reward = 21.00, steps = 27445\n",
      "17:21:30 [DEBUG] test episode 40: reward = 21.00, steps = 27445\n",
      "17:22:02 [DEBUG] test episode 41: reward = 21.00, steps = 27445\n",
      "17:22:34 [DEBUG] test episode 42: reward = 21.00, steps = 27445\n",
      "17:23:06 [DEBUG] test episode 43: reward = 21.00, steps = 27445\n",
      "17:23:39 [DEBUG] test episode 44: reward = 21.00, steps = 27445\n",
      "17:24:11 [DEBUG] test episode 45: reward = 21.00, steps = 27445\n",
      "17:24:43 [DEBUG] test episode 46: reward = 21.00, steps = 27445\n",
      "17:25:15 [DEBUG] test episode 47: reward = 21.00, steps = 27445\n",
      "17:25:48 [DEBUG] test episode 48: reward = 21.00, steps = 27445\n",
      "17:26:20 [DEBUG] test episode 49: reward = 21.00, steps = 27445\n",
      "17:26:52 [DEBUG] test episode 50: reward = 21.00, steps = 27445\n",
      "17:27:24 [DEBUG] test episode 51: reward = 21.00, steps = 27445\n",
      "17:27:56 [DEBUG] test episode 52: reward = 21.00, steps = 27445\n",
      "17:28:28 [DEBUG] test episode 53: reward = 21.00, steps = 27445\n",
      "17:29:01 [DEBUG] test episode 54: reward = 21.00, steps = 27445\n",
      "17:29:34 [DEBUG] test episode 55: reward = 21.00, steps = 27445\n",
      "17:30:07 [DEBUG] test episode 56: reward = 21.00, steps = 27445\n",
      "17:30:39 [DEBUG] test episode 57: reward = 21.00, steps = 27445\n",
      "17:31:12 [DEBUG] test episode 58: reward = 21.00, steps = 27445\n",
      "17:31:44 [DEBUG] test episode 59: reward = 21.00, steps = 27445\n",
      "17:32:17 [DEBUG] test episode 60: reward = 21.00, steps = 27445\n",
      "17:32:49 [DEBUG] test episode 61: reward = 21.00, steps = 27445\n",
      "17:33:22 [DEBUG] test episode 62: reward = 21.00, steps = 27445\n",
      "17:33:54 [DEBUG] test episode 63: reward = 21.00, steps = 27445\n",
      "17:34:26 [DEBUG] test episode 64: reward = 21.00, steps = 27445\n",
      "17:34:58 [DEBUG] test episode 65: reward = 21.00, steps = 27445\n",
      "17:35:31 [DEBUG] test episode 66: reward = 21.00, steps = 27445\n",
      "17:36:03 [DEBUG] test episode 67: reward = 21.00, steps = 27445\n",
      "17:36:36 [DEBUG] test episode 68: reward = 21.00, steps = 27445\n",
      "17:37:08 [DEBUG] test episode 69: reward = 21.00, steps = 27445\n",
      "17:37:40 [DEBUG] test episode 70: reward = 21.00, steps = 27445\n",
      "17:38:13 [DEBUG] test episode 71: reward = 21.00, steps = 27445\n",
      "17:38:45 [DEBUG] test episode 72: reward = 21.00, steps = 27445\n",
      "17:39:17 [DEBUG] test episode 73: reward = 21.00, steps = 27445\n",
      "17:39:50 [DEBUG] test episode 74: reward = 21.00, steps = 27445\n",
      "17:40:22 [DEBUG] test episode 75: reward = 21.00, steps = 27445\n",
      "17:40:54 [DEBUG] test episode 76: reward = 21.00, steps = 27445\n",
      "17:41:27 [DEBUG] test episode 77: reward = 21.00, steps = 27445\n",
      "17:41:59 [DEBUG] test episode 78: reward = 21.00, steps = 27445\n",
      "17:42:32 [DEBUG] test episode 79: reward = 21.00, steps = 27445\n",
      "17:43:04 [DEBUG] test episode 80: reward = 21.00, steps = 27445\n",
      "17:43:37 [DEBUG] test episode 81: reward = 21.00, steps = 27445\n",
      "17:44:10 [DEBUG] test episode 82: reward = 21.00, steps = 27445\n",
      "17:44:42 [DEBUG] test episode 83: reward = 21.00, steps = 27445\n",
      "17:45:15 [DEBUG] test episode 84: reward = 21.00, steps = 27445\n",
      "17:45:47 [DEBUG] test episode 85: reward = 21.00, steps = 27445\n",
      "17:46:20 [DEBUG] test episode 86: reward = 21.00, steps = 27445\n",
      "17:46:52 [DEBUG] test episode 87: reward = 21.00, steps = 27445\n",
      "17:47:24 [DEBUG] test episode 88: reward = 21.00, steps = 27445\n",
      "17:47:56 [DEBUG] test episode 89: reward = 21.00, steps = 27445\n",
      "17:48:29 [DEBUG] test episode 90: reward = 21.00, steps = 27445\n",
      "17:49:01 [DEBUG] test episode 91: reward = 21.00, steps = 27445\n",
      "17:49:33 [DEBUG] test episode 92: reward = 21.00, steps = 27445\n",
      "17:50:05 [DEBUG] test episode 93: reward = 21.00, steps = 27445\n",
      "17:50:37 [DEBUG] test episode 94: reward = 21.00, steps = 27445\n",
      "17:51:10 [DEBUG] test episode 95: reward = 21.00, steps = 27445\n",
      "17:51:42 [DEBUG] test episode 96: reward = 21.00, steps = 27445\n",
      "17:52:15 [DEBUG] test episode 97: reward = 21.00, steps = 27445\n",
      "17:52:47 [DEBUG] test episode 98: reward = 21.00, steps = 27445\n",
      "17:53:20 [DEBUG] test episode 99: reward = 21.00, steps = 27445\n",
      "17:53:20 [INFO] average episode reward = 21.00 ± 0.00\n"
     ]
    }
   ],
   "source": [
    "def play_episode(env, agent, max_episode_steps=None, mode=None, render=False):\n",
    "    \"\"\"Run one episode of `agent` in `env` and report its outcome.\n",
    "\n",
    "    The agent is asked for an action on every observation, including the\n",
    "    terminal one, so it also sees the final reward and the done flag.\n",
    "\n",
    "    Args:\n",
    "        env: environment whose reset() returns an observation and whose\n",
    "            step(action) returns a 4-tuple (observation, reward, done, info).\n",
    "        agent: object providing reset(mode=...), step(observation, reward,\n",
    "            done) and close().\n",
    "        max_episode_steps: if truthy, stop after this many environment steps.\n",
    "        mode: forwarded to agent.reset().\n",
    "        render: if true, call env.render() once per loop iteration.\n",
    "\n",
    "    Returns:\n",
    "        (sum of rewards, number of environment steps taken).\n",
    "    \"\"\"\n",
    "    observation = env.reset()\n",
    "    reward = 0.\n",
    "    done = False\n",
    "    agent.reset(mode=mode)\n",
    "    total_reward = 0.\n",
    "    step_count = 0\n",
    "    running = True\n",
    "    while running:\n",
    "        action = agent.step(observation, reward, done)\n",
    "        if render:\n",
    "            env.render()\n",
    "        if not done:\n",
    "            observation, reward, done, _ = env.step(action)\n",
    "            total_reward += reward\n",
    "            step_count += 1\n",
    "            # Stop early once the optional step budget is used up.\n",
    "            running = not (max_episode_steps and step_count >= max_episode_steps)\n",
    "        else:\n",
    "            running = False\n",
    "    agent.close()\n",
    "    return total_reward, step_count\n",
    "\n",
    "\n",
    "# Evaluate the agent over 100 episodes and log per-episode and summary results.\n",
    "logging.info('==== test ====')\n",
    "episode_rewards = []\n",
    "for episode in range(100):\n",
    "    episode_reward, elapsed_steps = play_episode(env, agent)\n",
    "    episode_rewards += [episode_reward]\n",
    "    logging.debug('test episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "reward_mean = np.mean(episode_rewards)\n",
    "reward_std = np.std(episode_rewards)\n",
    "logging.info('average episode reward = %.2f ± %.2f', reward_mean, reward_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation is finished; close the environment.\n",
    "env.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
